#!/usr/bin/env python

"""
This script parses the regular read-based output of a kraken2 run, producing a table of full standard taxonomic lineages of all
classifications (filling in NAs for lower ranks), with counts of how many reads went to that specific taxon. It depends on another `bit` program
(`bit-get-lineage-from-taxids`) and taxonkit; if `bit` was installed with conda, all should be swell :)

Input table expected to look like this:

U  A00159:145:H75T2DMXX:1:1101:7735:13792   unclassified (taxid 0)    16   0:0
U  A00159:145:H75T2DMXX:1:1101:11216:13557  unclassified (taxid 0)    30   0:0
U  A00159:145:H75T2DMXX:1:1101:22688:14074  unclassified (taxid 0)    26   0:0
U  A00159:145:H75T2DMXX:1:1101:1325:14559   unclassified (taxid 0)    31   0:0
U  A00159:145:H75T2DMXX:1:1101:23719:15013  unclassified (taxid 0)    30   0:0
C  A00159:145:H75T2DMXX:1:1102:11388:8312   Ochrobactrum (taxid 528)  194  0:12 1224:16 28211:7 528:15
U  A00159:145:H75T2DMXX:1:1102:15465:8390   unclassified (taxid 0)    27   0:0
U  A00159:145:H75T2DMXX:1:1102:6343:7560    unclassified (taxid 0)    271  0:237
U  A00159:145:H75T2DMXX:1:1102:30101:11600  unclassified (taxid 0)    26   0:0
U  A00159:145:H75T2DMXX:1:1101:19678:2221   unclassified (taxid 0)    279  0:245

Unclassified reads are reported on one row with "Unclassified" specified at each rank. If names are included, as in the example above, add the '--names-included' flag when running the program.
"""

import sys
import argparse
import pandas as pd
import subprocess
import os
import re

# matches the trailing "(taxid N)" that kraken2 appends to the classification field when run with `--use-names`
TAXID_IN_NAME_PATTERN = re.compile(r"\(taxid (\d+)\)\s*$")

# rank columns that get "Unclassified" in the single row summarizing unclassified reads
LINEAGE_RANKS = ["domain", "phylum", "class", "order", "family", "genus", "species"]


def parse_arguments():
    """
    Build the command-line interface and return the parsed arguments.

    Prints the help menu to stderr and exits with status 0 if the program was called with no arguments at all.
    """

    parser = argparse.ArgumentParser(description="This script parses the regular read-based output of a kraken2 run, producing a table of full standard taxonomic lineages of all \
                                              classifications (filling in NAs for lower ranks), with counts of how many reads went to that specific taxon. For version info, run `bit-version`.")

    required = parser.add_argument_group('required arguments')

    required.add_argument("-i", "--input-tsv", help="Input table produced by kraken2 run", action="store", dest="input_file", required=True)
    parser.add_argument("-o", "--output-table", help='Output table name (default: "output.tsv")', action="store", dest="output_file", default="output.tsv")
    parser.add_argument("--names-included", help='Add this flag if kraken2 was run with the `--use-names` flag', action="store_true")

    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(0)

    return parser.parse_args()


def extract_taxid(classification, names_included):
    """
    Return the taxid (as a string) held in the third column of a kraken2 output line.

    With `names_included`, the field is expected to look like "Ochrobactrum (taxid 528)";
    otherwise the field is returned unchanged.

    Raises ValueError if `names_included` is set but no "(taxid N)" suffix is found.
    """

    if not names_included:
        return classification

    match = TAXID_IN_NAME_PATTERN.search(classification)

    if match is None:
        raise ValueError("Expected a classification like 'name (taxid 123)' but found '"
                         + classification + "'. Was kraken2 really run with `--use-names`?")

    return match.group(1)


def count_classifications(lines, names_included=False):
    """
    Tally reads per taxid from an iterable of kraken2 per-read output lines.

    Returns a tuple of (dictionary of taxid string -> read count, count of unclassified reads).
    Lines starting with "U" are counted as unclassified, and blank lines are skipped.

    Raises ValueError for a classified line with fewer than 3 tab-separated fields.
    """

    taxid_counts = {}
    unclassified_count = 0

    for line_number, line in enumerate(lines, start=1):

        # adding to unclassified count and moving on if unclassified
        if line.startswith("U"):
            unclassified_count += 1
            continue

        stripped_line = line.strip()

        # tolerating blank lines (e.g., a trailing newline at the end of the file)
        if not stripped_line:
            continue

        fields = stripped_line.split("\t")

        if len(fields) < 3:
            raise ValueError("Line " + str(line_number) + " of the input does not look like kraken2 output "
                             "(expected at least 3 tab-separated columns).")

        # getting taxid classification of current read and adding to its count
        taxid = extract_taxid(fields[2], names_included)
        taxid_counts[taxid] = taxid_counts.get(taxid, 0) + 1

    return taxid_counts, unclassified_count


def get_lineages(taxids):
    """
    Run `bit-get-lineage-from-taxids` on the given taxids and return its output table as a dataframe.

    The intermediate files live in a uniquely named temporary directory, so concurrent runs
    cannot overwrite each other's files, and they are removed even if something fails.

    Raises subprocess.CalledProcessError if `bit-get-lineage-from-taxids` exits with a non-zero status.
    """

    # imported here so this block does not depend on changes to the imports at the top of the file
    import tempfile

    with tempfile.TemporaryDirectory(prefix="bit-convert-kraken2-") as tmp_dir:

        taxids_path = os.path.join(tmp_dir, "taxids.tmp")
        lineages_path = os.path.join(tmp_dir, "lineages.tmp")

        # writing out taxids to temp file
        with open(taxids_path, "w") as tmp_taxid_file:
            tmp_taxid_file.writelines(str(taxid) + "\n" for taxid in taxids)

        # getting taxid lineages (check=True so a failure is reported here, not as a confusing read error below)
        subprocess.run(["bit-get-lineage-from-taxids", "-i", taxids_path, "-o", lineages_path],
                       stdout=subprocess.DEVNULL, check=True)

        # reading in results as table before the temporary directory is removed
        return pd.read_csv(lineages_path, sep="\t")


def build_summary_table(lineage_tab, taxid_counts, unclassified_count):
    """
    Combine lineages and read counts into the final table, sorted by taxid.

    `lineage_tab` needs a "taxid" column to merge on, and is ignored (may be None) when `taxid_counts`
    is empty. Missing ranks are filled with "NA", one row with "Unclassified" at every rank holds
    `unclassified_count`, and a "percent_of_reads" column is added.
    """

    # building the unclassified row
    unclassified_row = {"taxid": 0}
    unclassified_row.update({rank: "Unclassified" for rank in LINEAGE_RANKS})
    unclassified_row["read_counts"] = unclassified_count
    unclassified_tab = pd.DataFrame(unclassified_row, index=[0])

    if taxid_counts:
        # converting taxid dict to dataframe
        taxid_counts_df = pd.DataFrame(list(taxid_counts.items()), columns=["taxid", "read_counts"])

        # setting to integer type so can be merged with lineage tab
        taxid_counts_df = taxid_counts_df.astype({"taxid": "int64"})

        # merging
        combined_tab = lineage_tab.merge(taxid_counts_df, on="taxid").fillna("NA")

        # adding in unclassified row (pandas append was dropped as of pandas 2.0, so using concat)
        combined_tab = pd.concat([combined_tab, unclassified_tab], ignore_index=True)

    else:
        # nothing was classified, so the unclassified row is the whole table
        combined_tab = unclassified_tab

    # adding in percent column
    combined_tab["percent_of_reads"] = combined_tab["read_counts"] / combined_tab["read_counts"].sum() * 100

    # sorting
    return combined_tab.sort_values(by=["taxid"])


def main():
    """Parse the kraken2 output named on the command line and write the lineage/count table."""

    args = parse_arguments()

    # iterating through read classifications
    with open(args.input_file, "r") as classifications:
        taxid_counts, unclassified_count = count_classifications(classifications, args.names_included)

    # getting standard lineage for each taxid (skipped if no read was classified, as there is nothing to look up)
    lineage_tab = get_lineages(taxid_counts) if taxid_counts else None

    combined_tab = build_summary_table(lineage_tab, taxid_counts, unclassified_count)

    # writing out
    combined_tab.to_csv(args.output_file, sep="\t", header=True, index=False)


if __name__ == "__main__":
    main()
